%This script runs the simulations used for Figures 2 and 3. 
%Corresponding author: Rodrigo Adao
%Date: 09/11/2024
%Input: model_data_fgkk.mat
%Output: Figures 2 and 3

%% Preliminaries

clear all;
close all;

% trade_target is a scalar (baseline 0.9). Besides any use made of it by the
% setup scripts, it selects the file simulation_shares_t<trade_target>.mat
% loaded below and is saved together with the simulation output.
% NOTE(review): presumably the share of trade covered by the varieties kept
% in the sample -- confirm against simulation_preliminary_steps.
trade_target = 0.9;

%Simulation parameters
Nsimul = 1500; %Number of Monte Carlo draws per scenario (baseline 1500)

%% Import Data
%set local paths
% Build the helper-function folder with the platform's file separator. The
% trailing filesep keeps the value identical to the former "..\Functions\"
% on Windows, in case later code appends file names to function_path.
% Note: ".." is resolved relative to the current working directory.
function_path = fullfile("..", "Functions") + filesep;
addpath(function_path,'-frozen');
set_local_paths;

%Import data and compute researcher predictions
simulation_preliminary_steps;

%% Create IV matrices 

% Naive IV matrices (no GE adjustment):
%   1) weight the tariff shares shareIV_tau with ww_dWmat_naive,
%   2) subtract Cn*(MCn_term1*.) -- presumably the projection on the
%      controls; confirm against simulation_preliminary_steps,
%   3) divide by var_adj, the mean of the squared welfare weights ww_dW.
var_adj = sum(ww_dW.^2)/Nobs;
share_wNC_nGE = ww_dWmat_naive*shareIV_tau;
share_wMC_nGE = (share_wNC_nGE - Cn*(MCn_term1*share_wNC_nGE))/var_adj;

% Drop the intermediates; only share_wMC_nGE (and var_adj) are kept
clearvars('share_wNC_nGE', 'share_NC_nGE', 'shareIV_tau')

%Auxiliary matrices for control and welfare adjustments
%Load the adjustment vector adjGE_wMC saved for this trade_target
load(save_share_path + "simulation_shares_t" + trade_target + ".mat", 'adjGE_wMC')

%Scale row i of the model shares shareMOD_dW by adjGE_wMC(i) (sparse
%diagonal matrix), then subtract C*(MC_term1*.) -- presumably the
%projection on the controls; confirm where C and MC_term1 are built.
adjGE_wMC_mat = spdiags(adjGE_wMC, 0, Nobs, Nobs);
share_wMC_dWa = adjGE_wMC_mat*shareMOD_dW;
share_wMC_dW  = share_wMC_dWa - C*(MC_term1*share_wMC_dWa);

%Drop intermediates and weighting matrices not used further below
clearvars('share_MC_dWa', 'share_wMC_dWa', 'adjGE_MC_mat', 'adjGE_wMC_mat', ...
          'ww_dWmat', 'ww_dWmat_naive')

%% Set up calibration for simulations

% Tariff shocks    
%    w=[ww_vm;ww_px];
%    shock = [dtau_vec(ind_Msample_ig ); dtaustar_vec(ind_Xsample_ig )];
%    moments_shocks = [mean(shock(shock~=0)), std(shock(shock~=0))];

% Every *_grid vector below stacks Npar = Nsd + Ngamma scenarios:
%   rows 1..Nsd      vary the st dev of the non-tariff shocks (Figure 2a);
%   rows Nsd+1..Npar fix that st dev at std_ep and vary the
%                    misspecification parameter (Figures 2b and 3).
sd_shock_grid_plot = linspace(.02, .3, 9)';
gamma_grid_plot    = [-1.5, -1, -.5, -.25, 0, .25, .5, 1, 1.5]';
Nsd    = length(sd_shock_grid_plot);
Ngamma = length(gamma_grid_plot);
Npar   = Nsd + Ngamma;

%Shock parameters
avg_sh = .02; %average tariff shock
std_sh = .06; %st dev of tariff shock
std_ep = .06; %st dev of other shocks in the misspecification scenarios

%Tariff-shifter draws: same mean and st dev in every scenario
sd_shifters_grid   = std_sh*ones(Npar,1);
mean_shifters_grid = avg_sh*ones(Npar,1);

%Non-tariff shocks (normal draws): zero mean; st dev taken from
%sd_shock_grid_plot in the first Nsd scenarios, std_ep afterwards
mean_deta_grid = zeros(Npar,1);
sd_da_grid     = [sd_shock_grid_plot; std_ep*ones(Ngamma,1)];
sd_dzstar_grid = [sd_shock_grid_plot; std_ep*ones(Ngamma,1)];
sd_dastar_grid = [sd_shock_grid_plot; std_ep*ones(Ngamma,1)];

%True-DGP parameters: baseline values held fixed across all scenarios
omega_star_grid = omega_star*ones(Npar,1);
sigma_star_grid = sigma_star*ones(Npar,1);
sigma_grid      = sigma     *ones(Npar,1);
eta_grid        = eta       *ones(Npar,1);
kappa_grid      = kappa     *ones(Npar,1);
rho_exp_grid    = ones(Npar,1);
rho_imp_grid    = ones(Npar,1);
gammaM_grid     = ones(Npar,1);
gammaX_grid     = ones(Npar,1);
gammaQ_grid     = ones(Npar,1);

%Misspecification draws: mean is 0 in the first Nsd scenarios and runs over
%gamma_grid_plot afterwards; st dev is 0 everywhere
gamma_mean_px_grid = [zeros(Nsd,1); gamma_grid_plot];
gamma_mean_pm_grid = [zeros(Nsd,1); gamma_grid_plot];
gamma_mean_rm_grid = [zeros(Nsd,1); gamma_grid_plot];

gamma_sd_px_grid = zeros(Npar,1);
gamma_sd_pm_grid = zeros(Npar,1);
gamma_sd_rm_grid = zeros(Npar,1);

%% Compute equilibrium given different draws of parameters

Npar
% Preallocate the outputs that the parfor loop below fills slice by slice:
%   ests    : 14 statistics per draw (layout of est_n) x Nsimul x Npar
%   std_vec : st devs of dy (all obs, indM, indT, indX)  x Nsimul x Npar
% Preallocating std_vec fixes its size up front and guarantees it exists
% when it is saved after each scenario.
ests = zeros(14, Nsimul, Npar);
std_vec = zeros(4, Nsimul, Npar);

% Main loop over the Npar calibration scenarios. For each scenario j:
% read the scenario parameters from the *_grid vectors, build the true DGP
% and its response matrices, run Nsimul simulated draws in parallel, and
% re-save the accumulated results.
for j=1:Npar
    tic
    display('------------- running new j -----------------')
    j
    
    %shock process
    sd_shifters  = sd_shifters_grid(j);
    mean_shifters = mean_shifters_grid(j);
    %Scale factor applied to the welfare-adjusted IV matrix below:
    %Nobs times the mean / st dev ratio of the tariff shifters.
    %NOTE(review): the rationale for this scaling is not visible here -- confirm
    adj = Nobs*(mean_shifters/sd_shifters);

    %Mean and st devs of the non-tariff shocks (da, dzstar, dastar)
    mean_deta = mean_deta_grid(j);
    sd_da     = sd_da_grid(j);
    sd_dzstar = sd_dzstar_grid(j);
    sd_dastar = sd_dastar_grid(j);

    %Auxiliary matrices for computing IV       
    share_wMC_dW_j = adj*share_wMC_dW;
    
    %Generate true DGP: scenario-j parameter values
    omega_starDGP = omega_star_grid(j);    
    sigmaDGP = sigma_grid(j);           
    etaDGP = eta_grid(j);               
    kappaDGP = kappa_grid(j);           
    sigma_starDGP = sigma_star_grid(j); 
    rhoMDGP = rho_imp_grid(j) ;
    rhoXDGP = rho_exp_grid(j) ;
    gammaMDGP = gammaM_grid(j) ;
    gammaXDGP = gammaX_grid(j) ;
    gammaQDGP = gammaQ_grid(j) ;    
    
    %Per-observation misspecification draws: Nobs normal draws are taken for
    %each of the pm / rm / px (mean, st dev) pairs, and observation i keeps the
    %draw selected by its indicator (indM -> pm, indT -> rm, indX -> px).
    %gammaDGP is the sparse diagonal matrix diag(1 + gamma_n).
    %NOTE(review): assumes indM/indT/indX are mutually exclusive 0/1 indicators -- confirm
    gamma_n = (indM==1).*( normrnd(gamma_mean_pm_grid(j), gamma_sd_pm_grid(j), Nobs, 1) )... 
            + (indT==1).*( normrnd(gamma_mean_rm_grid(j), gamma_sd_rm_grid(j), Nobs, 1) )...
            + (indX==1).*( normrnd(gamma_mean_px_grid(j), gamma_sd_px_grid(j), Nobs, 1) );
    gammaDGP = spdiags(1+ gamma_n, 0, Nobs,Nobs);

    %Recover the DGP-level objects (*_DGP0 outputs) from the baseline data
    %(x_ig_0, m_ig_0, tariffs, sales, IO and labor data) under the scenario's
    %DGP parameters. NOTE(review): exact semantics live in invert_param_full.
     [delta_ig_DGP0, a_star_ig_DGP0, z_star_ig_DGP0, a_ig_DGP0, aM_g_DGP0, AM_s_DGP0, AD_s_DGP0, betaT_DGP0, beta_s_DGP0, alphaL_s_DGP0, alphaI_s_DGP0, alpha_ks_DGP0, barL_rs_DGP0, Z_rs_DGP0, D_DGP0, E_s_DGP0, F_DGP0] = ... 
        invert_param_full(x_ig_0, m_ig_0, tau_star_ig_0, tau_ig_0, trade_con_share, tot_labor_comp, tot_intermediate, sales_s, IO_sales, Lsr, ...
        Dsg, DX, DM, omega_starDGP, sigmaDGP, etaDGP, kappaDGP, sigma_starDGP, gammaMDGP, gammaXDGP,  tol_parm);
 
    %Response matrices under the true DGP, named rhoDGP_<outcome>_ig_<shock>
    %with outcomes qm/pm/px/pstar/rm and shocks tau/taustar/a/zstar/astar/delta.
    %NOTE(review): definitions live in compute_equilibrium_foa_full -- confirm.
    [p_s_IV, PM_s_IV, ...
         rhoDGP_qm_ig_tau_ig, rhoDGP_qm_ig_taustar_ig, rhoDGP_qm_ig_a_ig, rhoDGP_qm_ig_zstar_ig, rhoDGP_qm_ig_astar_ig, rhoDGP_qm_ig_delta_ig, ...
         rhoDGP_pm_ig_tau_ig, rhoDGP_pm_ig_taustar_ig, rhoDGP_pm_ig_a_ig, rhoDGP_pm_ig_zstar_ig, rhoDGP_pm_ig_astar_ig, rhoDGP_pm_ig_delta_ig, ...
         rhoDGP_px_ig_tau_ig, rhoDGP_px_ig_taustar_ig, rhoDGP_px_ig_a_ig, rhoDGP_px_ig_zstar_ig, rhoDGP_px_ig_astar_ig, rhoDGP_px_ig_delta_ig, ...
         rhoDGP_pstar_ig_tau_ig, rhoDGP_pstar_ig_taustar_ig, rhoDGP_pstar_ig_a_ig, rhoDGP_pstar_ig_zstar_ig, rhoDGP_pstar_ig_astar_ig, rhoDGP_pstar_ig_delta_ig, ...
         rhoDGP_rm_ig_tau_ig, rhoDGP_rm_ig_taustar_ig, rhoDGP_rm_ig_a_ig, rhoDGP_rm_ig_zstar_ig, rhoDGP_rm_ig_astar_ig, rhoDGP_rm_ig_delta_ig] ...
         = compute_equilibrium_foa_full( delta_ig_DGP0, a_star_ig_DGP0, z_star_ig_DGP0, a_ig_DGP0, aM_g_DGP0, AM_s_DGP0, AD_s_DGP0, betaT_DGP0, beta_s_DGP0, alphaL_s_DGP0, alphaI_s_DGP0, alpha_ks_DGP0, barL_rs_DGP0, Z_rs_DGP0, D_DGP0, ... 
             tau_ig_0, tau_star_ig_0, m_ig_0, E_s_DGP0, p_s_0, Dsg, DX, DM, omega_starDGP, sigmaDGP, etaDGP, kappaDGP, sigma_starDGP, nu, epsilon, gammaQDGP, gammaXDGP, gammaMDGP, rhoXDGP, rhoMDGP, tol_conv, adj_inner_g0, adj_outter_g0, ... 
             ind_Mshifts_vec, ind_Xshifts_vec, ind_Msample_ig, ind_Xsample_ig ); 

    %True responses to the tariff shifters [tau, taustar], stacked as
    %[pm; rm; px] and row-scaled by (1 + gamma_n) -- the misspecification
    shareDGP_pm      = [rhoDGP_pm_ig_tau_ig    , rhoDGP_pm_ig_taustar_ig   ];
    shareDGP_rm      = [rhoDGP_rm_ig_tau_ig    , rhoDGP_rm_ig_taustar_ig   ];
    shareDGP_px      = [rhoDGP_px_ig_tau_ig    , rhoDGP_px_ig_taustar_ig   ];
    shareDGP_dW = gammaDGP*[shareDGP_pm; shareDGP_rm; shareDGP_px];
  
    %True responses of the same stacked outcomes to the non-tariff shocks
    %(no gammaDGP scaling applied)
    shareDGP_dW_da     = [rhoDGP_pm_ig_a_ig    ; rhoDGP_rm_ig_a_ig    ; rhoDGP_px_ig_a_ig    ];
    shareDGP_dW_dzstar = [rhoDGP_pm_ig_zstar_ig; rhoDGP_rm_ig_zstar_ig; rhoDGP_px_ig_zstar_ig];
    shareDGP_dW_dastar = [rhoDGP_pm_ig_astar_ig; rhoDGP_rm_ig_astar_ig; rhoDGP_px_ig_astar_ig];   
 
    %Drop per-scenario intermediates that the parfor loop does not use
    %(shareDGP_dW and shareDGP_dW_* are kept: their names do not match the patterns)
    clear    rhoDGP* shareDGP_p* shareDGP_r* shareDGP_q* delta_ig_DGP0 a_star_ig_DGP0 z_star_ig_DGP0 a_ig_DGP0 aM_g_DGP0 barL_rs_DGP0 Z_rs_DGP0
   
    toc
    tic
    %Simulation across draws of shocks 
    parfor n=1:Nsimul

    % simulated data
        %simulated shifts: NMs + NXs normal tariff shifters, then
        %standardized (zero mean, unit st dev) for use in the instruments
        dtau_sim = normrnd(mean_shifters, sd_shifters, NMs, 1);
        dtaustar_sim = normrnd(mean_shifters, sd_shifters, NXs, 1);  
        shifters = [dtau_sim; dtaustar_sim];
        shiftersIV = (shifters - mean(shifters))/std(shifters);
        
        %simulated shocks: non-tariff shocks mapped through the DGP responses
        da_sim     = normrnd(mean_deta, sd_da    , NMs, 1);
        dzstar_sim = normrnd(mean_deta, sd_dzstar, NMs, 1);
        dastar_sim = normrnd(mean_deta, sd_dastar, NXs, 1);
        detaDGP = shareDGP_dW_da*da_sim + shareDGP_dW_dzstar*dzstar_sim + shareDGP_dW_dastar*dastar_sim ;

        %Simulated predictions and outcomes with true parameters:
        %dx = researcher-model prediction, dxstar = true DGP response,
        %dy = simulated outcome, dyt_NC = prediction error dy - dx
        %(dyn_NC restricts it to the sample_naive observations)
        dx     = shareMOD_dW*shifters;
        dxstar = shareDGP_dW*shifters ;
        dy = dxstar + detaDGP;
        dyt_NC = dy - dx;
        dyn_NC = dyt_NC(sample_naive);
        %st devs of dy: all observations, then the indM, indT, indX subsets
        std_vec(:,n,j)  = [std(dy); std(dy(indM==1)); std(dy(indT==1)); std(dy(indX==1))];

        % Welfare statistics: ww_dW-weighted aggregates of the predicted and
        % true changes, and their gap Delta_W = dWstar - dW
        dW = ww_dW'*dx;    
        dWstar = ww_dW'*dxstar;
        Delta_W = dWstar - dW;
       
    % Tests with the true parameters
        
        %correlation-based tests: corr(dy,dx), mean squared prediction
        %error, and R^2 (stats(1) of regress) of dy on [1, dx]
        correl_pred = corr([dy, dx ]);
        correl_pred = correl_pred(2,1);
        MSE_pred = mean( dyt_NC.^2 );
        [bt_slope,~,~,~,stats_slope] = regress(dy, [ones(Nobs,1), dx]);
        R2_pred = stats_slope(1);
        
        %naive -- no GE
        %implement_test outputs: coefficient, SE, and two rejection outcomes
        %(rj, r0); critical is defined outside this view.
        %NOTE(review): exact test definitions live in implement_test.
        zn_wMC = share_wMC_nGE*shiftersIV;          
        [bt_zn_wMC, se_zn_wMC, rj_zn_wMC, r0_zn_wMC] = implement_test(dyn_NC, zn_wMC, share_wMC_nGE, shiftersIV, critical, 0);
        
        %welfare adjusted (uses the full dyt_NC, not the naive sample)
        zw_wMC = share_wMC_dW_j*shiftersIV;  
        [bt_zw_wMC, se_zw_wMC, rj_zw_wMC, r0_zw_wMC] = implement_test(dyt_NC, zw_wMC, share_wMC_dW_j, shiftersIV, critical, 0);       
   
        %Output: 14 statistics in this order --
        % 1 dW, 2 dWstar, 3 Delta_W, 4 corr, 5 R2, 6 MSE,
        % 7-8 coefficients (naive, welfare-adjusted), 9-10 SEs,
        % 11-12 rj outcomes, 13-14 r0 outcomes
       est_n = [dW, dWstar, Delta_W, correl_pred, R2_pred, MSE_pred, ...
               bt_zn_wMC, bt_zw_wMC, ...
               se_zn_wMC, se_zw_wMC, ...
               rj_zn_wMC, rj_zw_wMC, ...
               r0_zn_wMC, r0_zw_wMC];
    
    % save outcomes for step n
        ests(:,n,j) = est_n';
        
    end
    toc
    %Checkpoint: the output file is rewritten after every scenario j, so it
    %holds the results accumulated so far if the run stops early
    save(save_simulation_path + "output_Figs_2_3.mat", 'ests', 'tradestat_sample', 'sd_shifters_grid', 'sd_da_grid', 'gamma_mean_px_grid','gamma_mean_pm_grid', 'gamma_mean_rm_grid', 'Npar', 'min_imp', 'min_exp', 'Nsd', 'Ngamma', 'std_vec', 'trade_target');
end

%% Report results

load(save_simulation_path + "output_Figs_2_3.mat")

% Summarize the Nsimul draws of each scenario j into an 8 x 5 table
% ests_stat(:,:,j):
%   rows    : 1 dW, 2 dWstar, 3 Delta_W, 4 corr(dy,dx), 5 R2, 6 MSE,
%             7 naive IV coefficient, 8 preferred (welfare-adjusted) IV
%   columns : 1 mean, 2 st dev (normalized by Nsimul), 3 mean SE,
%             4 mean r0 outcome, 5 mean rj outcome
%   Columns 3-5 are zero for rows 1-6, which carry no test.
% stdy(:,j) averages the four st devs of dy stored in std_vec.
% Both outputs are preallocated instead of being grown inside the loop.
ests_stat = zeros(8, 5, Npar);
stdy = zeros(size(std_vec,1), Npar);
for j = 1:Npar
    ests_j = ests(:,:,j);

    estimates_j = ests_j(1:8,:);
    SE   = [zeros(6,1); mean(ests_j(9:10,:),2)];
    rej  = [zeros(6,1); mean(ests_j(11:12,:),2)];
    rej0 = [zeros(6,1); mean(ests_j(13:14,:),2)];
    stdy(:,j) = mean(std_vec(:,:,j),2);

    ests_stat(:,:,j) = [mean(estimates_j,2), std(estimates_j,1,2), SE, rej0, rej];
end


 %% Figure 2

w_min = -.06;
w_max = .06;

% Font and marker sizes (also reused by Figure 3)
axes_size   = 18;
label_size  = 20;
legend_size = 18;
ylabel_size = 14;
marker_size = 10;

% Rows of ests_stat plotted in Figure 2: corr(dy,dx) and the preferred IV
correl_version = 4;
ACDIV_version  = 8;
ymin = -1;
ymax = 1;

%--- Figure 2a: scenarios 1..Nsd, varying sigma_e (st dev of the
%    non-tariff shocks)
j0 = 1;
jf = Nsd;
x_axis         = sd_da_grid(j0:jf);
correl_series  = squeeze(ests_stat(correl_version, 1, j0:jf));
rej0_IV_series = squeeze(ests_stat(ACDIV_version, 4, j0:jf));

figure('DefaultAxesFontSize', axes_size, 'Position', [10 10 600 600]);
plot_lines_correl = tiledlayout(1,1);
nexttile
plot(x_axis, correl_series, 'Marker', 'd', 'MarkerFaceColor', [0.5 0.5 0.5], ...
    'MarkerSize', marker_size, 'Color', [0.5 0.5 0.5])
hold on
plot(x_axis, rej0_IV_series, 'Marker', 'o', 'MarkerFaceColor', 'black', ...
    'MarkerSize', marker_size, 'Color', 'black')
yline(0, '-black', 'HandleVisibility', 'off')
ylim([ymin ymax])
legend('Corr(\Delta y, \Delta x)', 'H0 rejection rate at 5%, preferred IV', ...
    'Location', 'southeast', 'FontSize', legend_size)
ylabel('', 'FontSize', ylabel_size)
xticks([0 0.1 0.2 0.3])
yticks([-1, -0.5, 0, 0.5, 1])
xlabel('St. Dev. of non-tariff shocks', 'FontSize', label_size)

%--- Figure 2b: misspecification scenarios Nsd+1..Nsd+Ngamma, plotted
%    against the average welfare gap (row 3, column 1 of ests_stat)
j0 = Nsd+1;
jf = j0+Ngamma-1;
x_axis         = squeeze(ests_stat(3, 1, j0:jf));
correl_series  = squeeze(ests_stat(correl_version, 1, j0:jf));
rej0_IV_series = squeeze(ests_stat(ACDIV_version, 4, j0:jf));
xmin = w_min;
xmax = w_max;

figure('DefaultAxesFontSize', axes_size, 'Position', [10 10 600 600]);
plot_mis2_correl = tiledlayout(1,1);
nexttile
plot(x_axis, correl_series, 'Marker', 'd', 'MarkerFaceColor', [0.5 0.5 0.5], ...
    'MarkerSize', marker_size, 'Color', [0.5 0.5 0.5])
hold on
plot(x_axis, rej0_IV_series, 'Marker', 'o', 'MarkerFaceColor', 'black', ...
    'MarkerSize', marker_size, 'Color', 'black')
yline(0, '-black', 'HandleVisibility', 'off')
xline(0, '-black', 'HandleVisibility', 'off')
ylim([ymin ymax])
xlim([xmin xmax])
xticks([-.06 -.04 -.02 0 .02 .04 .06])
yticks([-1, -0.5, 0, 0.5, 1])
legend('Corr(\Delta y, \Delta x)', 'H0 rejection rate at 5%, preferred IV', ...
    'Location', 'southeast', 'FontSize', legend_size)
xlabel('E_t[W (\Delta x* ) - W (\Delta x )]', 'FontSize', label_size)

% Export each panel as PNG and as color EPS
saveas(plot_lines_correl, graph_path + "Fig_2a.png")
saveas(plot_lines_correl, graph_path + "Fig_2a.eps", 'epsc');
saveas(plot_mis2_correl, graph_path + "Fig_2b.png")
saveas(plot_mis2_correl, graph_path + "Fig_2b.eps", 'epsc');


%% Figure 3

% Misspecification scenarios only (Nsd+1..Nsd+Ngamma); x axis is the
% average welfare gap (row 3, column 1 of ests_stat)
naive_version = 7;
ACDIV_version = 8;
j0 = Nsd+1;
jf = Nsd+Ngamma;
DeltaW = squeeze(ests_stat(3, 1, j0:jf));

%--- Average estimated coefficient (column 1 of ests_stat)
beta_naive_series = squeeze(ests_stat(naive_version, 1, j0:jf));
beta_IV_series    = squeeze(ests_stat(ACDIV_version, 1, j0:jf));
x_axis = DeltaW;
xmin = w_min;
xmax = w_max;
ymin = w_min;
ymax = w_max;
% (optional 45-degree reference line, currently disabled:
%  line_45d = xmin:.02:xmax; plot(line_45d, line_45d, '-','Color','black','LineWidth',1))

figure('DefaultAxesFontSize', axes_size, 'Position', [10 10 600 600]);
plot_lines_avg = tiledlayout(1,1);
nexttile
plot(x_axis, beta_naive_series, 'Marker', 'd', 'MarkerFaceColor', [0.5 0.5 0.5], ...
    'MarkerSize', marker_size, 'Color', [0.5 0.5 0.5])
hold on
plot(x_axis, beta_IV_series, 'Marker', 'o', 'MarkerFaceColor', 'black', ...
    'MarkerSize', marker_size, 'Color', 'black')
yline(0, '-black', 'HandleVisibility', 'off')
xline(0, '-black', 'HandleVisibility', 'off')
ylim([ymin ymax])
xlim([xmin xmax])
legend('naive IV', 'preferred IV', 'Location', 'north', 'FontSize', legend_size)
xlabel('E_t[W (\Delta x* ) - W (\Delta x )]', 'FontSize', label_size)

%--- Rejection rate (column 4 of ests_stat)
rej0_naive_series = squeeze(ests_stat(naive_version, 4, j0:jf));
rej0_IV_series    = squeeze(ests_stat(ACDIV_version, 4, j0:jf));
x_axis = DeltaW;
ymin = 0;
ymax = 1;

figure('DefaultAxesFontSize', axes_size, 'Position', [10 10 600 600]);
plot_lines_rej = tiledlayout(1,1);
nexttile
plot(x_axis, rej0_naive_series, 'Marker', 'd', 'MarkerFaceColor', [0.5 0.5 0.5], ...
    'MarkerSize', marker_size, 'Color', [0.5 0.5 0.5])
hold on
plot(x_axis, rej0_IV_series, 'Marker', 'o', 'MarkerFaceColor', 'black', ...
    'MarkerSize', marker_size, 'Color', 'black')
yline(0, '-black', 'HandleVisibility', 'off')
xline(0, '-black', 'HandleVisibility', 'off')
ylim([ymin ymax])
xlim([xmin xmax])
legend('naive IV', 'preferred IV', 'Location', 'north', 'FontSize', legend_size)
yticks([0 .2 .4 .6 .8 1])
xlabel('E_t[W (\Delta x* ) - W (\Delta x )]', 'FontSize', label_size)

% Export: rejection rates -> Fig_3a, average coefficients -> Fig_3b
saveas(plot_lines_rej, graph_path + "Fig_3a.png")
saveas(plot_lines_rej, graph_path + "Fig_3a.eps", 'epsc');
saveas(plot_lines_avg, graph_path + "Fig_3b.png")
saveas(plot_lines_avg, graph_path + "Fig_3b.eps", 'epsc');
